{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "9s5MSGYOqXrk"
      },
      "source": [
        "# Iterative Patch Selection - Simple Example\n",
        "\n",
        "In this notebook, you can play around with IPS (e.g., by loading it into Google Colab)! For simplicity and illustration purposes, we apply IPS to a subset of the MNIST dataset. Images are resized to 500x500 and 25 patches of size 100x100 are extracted. We use a memory and iteration size of 5 patches, respectively.\n",
        "\n",
        "IPS is applied in the eager loading setting, i.e. the full input batch is directly loaded onto the GPU. Furthermore, the positional encoding is omitted here. For full control over data loading options (eager/eager sequential/lazy), positional encoding, datasets used within the experiments and tracking of computational efficiency metrics, we refer to the more complete code in the Github repository."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jd5kgN2vAv_1"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import random\n",
        "import math\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision\n",
        "from torchvision.models import resnet18"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IDLCwNnLxf3r"
      },
      "outputs": [],
      "source": [
        "def set_seed(seed: int = 42) -> None:\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
        "    print(f\"Random seed set as {seed}\")\n",
        "\n",
        "set_seed()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ssQx95lqBbRa"
      },
      "outputs": [],
      "source": [
        "data_dir = '/files/'\n",
        "img_size = 500\n",
        "n_chan = 1\n",
        "n_class = 10\n",
        "patch_size = 100\n",
        "patch_stride = 100\n",
        "batch_size = 128\n",
        "n_epoch = 10\n",
        "\n",
        "M = 5 # Number of selected patches\n",
        "I = 5 # Number of patches concatenated with memory buffer in each iteration\n",
        "D = 512 # Feature dim"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I6Yj14kPeWKc"
      },
      "source": [
        "# Data preparation\n",
        "\n",
        "We resize the MNIST images to 500x500 and create 25 non-overlapping patches of size 100x100. To reduce training time, a random subset of 10,000 training images will be used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L8faOhSjBIDo"
      },
      "outputs": [],
      "source": [
        "def patchify(img, patch_size, patch_stride):\n",
        "    patches = img.unfold(\n",
        "        1, patch_size, patch_stride\n",
        "    ).unfold(\n",
        "        2, patch_size, patch_stride\n",
        "    ).permute(1, 2, 0, 3, 4)\n",
        "    patches = patches.reshape(-1, *patches.shape[2:])\n",
        "    return patches\n",
        "\n",
        "transform=torchvision.transforms.Compose([\n",
        "  torchvision.transforms.Resize(img_size),\n",
        "  torchvision.transforms.ToTensor(),\n",
        "  torchvision.transforms.Lambda(lambda x: patchify(x, patch_size, patch_stride))\n",
        "])\n",
        "\n",
        "subset_size = 10000\n",
        "train_set = torchvision.datasets.MNIST(data_dir, train=True, download=True, transform=transform)\n",
        "subset_indices = torch.randperm(len(train_set))[:subset_size]\n",
        "train_set = torch.utils.data.Subset(train_set, subset_indices)\n",
        "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "test_set = torchvision.datasets.MNIST(data_dir, train=True, download=True, transform=transform)\n",
        "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pXh5XRIlfSkl"
      },
      "source": [
        "# Cross-Attention Transformer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jRGtW6S8Gr4y"
      },
      "outputs": [],
      "source": [
        "class ScaledDotProductAttention(nn.Module):\n",
        "  def __init__(self, temperature, attn_dropout=0.1):\n",
        "    super().__init__()\n",
        "    self.temperature = temperature\n",
        "    self.dropout = nn.Dropout(attn_dropout)\n",
        "\n",
        "  def compute_attn(self, q, k):\n",
        "    attn = torch.matmul(q / self.temperature, k.transpose(2, 3))\n",
        "    attn = self.dropout(torch.softmax(attn, dim=-1))\n",
        "    return attn\n",
        "\n",
        "  def forward(self, q, k, v):\n",
        "    attn = self.compute_attn(q, k)\n",
        "    output = torch.matmul(attn, v)\n",
        "    return output\n",
        "\n",
        "class MultiHeadCrossAttention(nn.Module):\n",
        "  def __init__(self, n_token, H, D, D_k, D_v, attn_dropout, dropout):\n",
        "    super().__init__()\n",
        "    self.n_token = n_token\n",
        "    self.H = H\n",
        "    self.D_k = D_k\n",
        "    self.D_v = D_v\n",
        "\n",
        "    self.q = nn.Parameter(torch.empty((1, n_token, D)))\n",
        "    q_init_val = math.sqrt(1 / D_k)\n",
        "    nn.init.uniform_(self.q, a=-q_init_val, b=q_init_val)\n",
        "\n",
        "    self.q_w = nn.Linear(D, H * D_k, bias=False)\n",
        "    self.k_w = nn.Linear(D, H * D_k, bias=False)\n",
        "    self.v_w = nn.Linear(D, H * D_v, bias=False)\n",
        "    self.fc = nn.Linear(H * D_v, D, bias=False)\n",
        "\n",
        "    self.attention = ScaledDotProductAttention(\n",
        "        temperature=D_k ** 0.5,\n",
        "        attn_dropout=attn_dropout\n",
        "    )\n",
        "\n",
        "    self.dropout = nn.Dropout(dropout)\n",
        "    self.layer_norm = nn.LayerNorm(D, eps=1e-6)\n",
        "\n",
        "  def get_attn(self, x):\n",
        "    D_k, H, n_token = self.D_k, self.H, self.n_token\n",
        "    B, len_seq = x.shape[:2]\n",
        "\n",
        "    q = self.q_w(self.q).view(1, n_token, H, D_k)\n",
        "    k = self.k_w(x).view(B, len_seq, H, D_k)\n",
        "\n",
        "    q, k = q.transpose(1, 2), k.transpose(1, 2)\n",
        "\n",
        "    attn = self.attention.compute_attn(q, k)\n",
        "    return attn\n",
        "\n",
        "  def forward(self, x):\n",
        "    D_k, D_v, H, n_token = self.D_k, self.D_v, self.H, self.n_token\n",
        "    B, len_seq = x.shape[:2]\n",
        "\n",
        "    # project and separate heads\n",
        "    q = self.q_w(self.q).view(1, n_token, H, D_k)\n",
        "    k = self.k_w(x).view(B, len_seq, H, D_k)\n",
        "    v = self.v_w(x).view(B, len_seq, H, D_v)\n",
        "\n",
        "    # transpose for attention dot product: B x H x len_seq x D_k or D_v\n",
        "    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)\n",
        "    # cross-attention\n",
        "    x = self.attention(q, k, v)\n",
        "\n",
        "    # transpose again: B x n_token x H x D_v\n",
        "    # concat heads: B x n_token x (H * D_v)\n",
        "    x = x.transpose(1, 2).contiguous().view(B, n_token, -1)\n",
        "    # combine heads\n",
        "    x = self.dropout(self.fc(x))\n",
        "    # residual connection + layernorm\n",
        "    x += self.q\n",
        "    x = self.layer_norm(x)\n",
        "    return x\n",
        "\n",
        "class MLP(nn.Module):\n",
        "  def __init__(self, D, D_inner, dropout):\n",
        "    super().__init__()\n",
        "    self.w_1 = nn.Linear(D, D_inner)\n",
        "    self.w_2 = nn.Linear(D_inner, D)\n",
        "    self.layer_norm = nn.LayerNorm(D, eps=1e-6)\n",
        "    self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "  def forward(self, x):\n",
        "    residual = x\n",
        "    x = self.w_2(torch.relu(self.w_1(x)))\n",
        "    x = self.dropout(x)\n",
        "    x += residual\n",
        "    x = self.layer_norm(x)\n",
        "    return x\n",
        "\n",
        "class Transformer(nn.Module):\n",
        "  def __init__(self, n_token=1, H=8, D=512, D_k=64, D_v=64, D_inner=2048, attn_dropout=0.1, dropout=0.1):\n",
        "    super().__init__()\n",
        "    self.crs_attn = MultiHeadCrossAttention(n_token, H, D, D_k, D_v, attn_dropout, dropout)\n",
        "    self.mlp = MLP(D, D_inner, dropout)\n",
        "  \n",
        "  def get_scores(self, x):\n",
        "    attn = self.crs_attn.get_attn(x)\n",
        "    # Average scores over heads and tasks\n",
        "    return attn.mean(dim=1).transpose(1, 2).mean(-1)\n",
        "\n",
        "  def forward(self, x):\n",
        "    return self.mlp(self.crs_attn(x))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "08UUHzvxfa0C"
      },
      "source": [
        "# Iterative Patch Selection\n",
        "Network that runs the full pipeline including patch encoder, IPS, patch aggregation module and classification head. We use a ResNet-18 trained from scratch as patch encoder in this example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4oB2rc5dCmAG"
      },
      "outputs": [],
      "source": [
        "class IPSNet(nn.Module):\n",
        "  def __init__(self, M, I, D):\n",
        "    super().__init__()\n",
        "    self.M = M\n",
        "    self.I = I\n",
        "    self.D = D\n",
        "\n",
        "    self.patch_encoder = resnet18()\n",
        "    self.patch_encoder.conv1 = nn.Conv2d(n_chan, 64, kernel_size=7, stride=2, padding=3, bias=False)\n",
        "    self.patch_encoder.fc = nn.Identity() #just need features\n",
        "\n",
        "    self.transf = Transformer()\n",
        "    self.head = nn.Sequential(nn.Linear(D, n_class))\n",
        "  \n",
        "  def score_and_select(self, emb, M, idx):\n",
        "    \"\"\"Scores embeddings and selects the top-M embeddings\"\"\"\n",
        "    # Obtain scores from transformer\n",
        "    attn = self.transf.get_scores(emb) # (B, M+I)\n",
        "\n",
        "    # Get indices of top-scoring patches\n",
        "    top_idx = torch.topk(attn, M, dim=-1)[1] # (B, M)\n",
        "    \n",
        "    # Update memory buffers\n",
        "    mem_emb = torch.gather(emb, 1, top_idx.unsqueeze(-1).expand(-1,-1,self.D))\n",
        "    mem_idx = torch.gather(idx, 1, top_idx)\n",
        "\n",
        "    return mem_emb, mem_idx\n",
        "\n",
        "  # IPS runs in no-gradient mode\n",
        "  @torch.no_grad()\n",
        "  def ips(self, patches):\n",
        "    \"\"\"Iterative Patch Selection\"\"\"\n",
        "    patch_shape = patches.shape\n",
        "    B, N = patch_shape[:2]\n",
        "\n",
        "    # Shortcut: IPS not required when memory is larger than total number of patches\n",
        "    if M >= N:\n",
        "      return patches\n",
        "    \n",
        "    # IPS runs in evaluation mode\n",
        "    if self.training:\n",
        "        self.patch_encoder.eval()\n",
        "        self.transf.eval()\n",
        "\n",
        "    # Init memory buffer\n",
        "    init_patch = patches[:,:M]\n",
        "    # Apply patch encoder\n",
        "    mem_emb = self.patch_encoder(init_patch.reshape(-1, *patch_shape[2:]))\n",
        "    mem_emb = mem_emb.view(B, M, -1)\n",
        "    \n",
        "    # Init memory indices in order to select patches at the end of IPS.\n",
        "    idx = torch.arange(N, dtype=torch.int64, device=mem_emb.device).unsqueeze(0).expand(B, -1)\n",
        "    mem_idx = idx[:,:M]\n",
        "\n",
        "    # Apply IPS for `n_iter` iterations\n",
        "    n_iter = math.ceil((N - M) / I)\n",
        "    for i in range(n_iter):\n",
        "        # Get next patches\n",
        "        start_idx = i * I + M\n",
        "        end_idx = min(start_idx + I, N)\n",
        "\n",
        "        iter_patch = patches[:, start_idx:end_idx]\n",
        "        iter_idx = idx[:, start_idx:end_idx]\n",
        "\n",
        "        # Embed patches\n",
        "        iter_emb = self.patch_encoder(iter_patch.reshape(-1, *patch_shape[2:]))\n",
        "        iter_emb = iter_emb.view(B, -1, D)\n",
        "        \n",
        "        # Concatenate with memory buffer\n",
        "        all_emb = torch.cat((mem_emb, iter_emb), dim=1)\n",
        "        all_idx = torch.cat((mem_idx, iter_idx), dim=1)\n",
        "\n",
        "        # Select Top-M patches according to cross-attention scores\n",
        "        mem_emb, mem_idx = self.score_and_select(all_emb, M, all_idx)\n",
        "\n",
        "    # Select patches\n",
        "    n_dim_expand = len(patch_shape) - 2\n",
        "    mem_patch = torch.gather(patches, 1, \n",
        "        mem_idx.view(B, -1, *(1,)*n_dim_expand).expand(-1, -1, *patch_shape[2:])\n",
        "    )\n",
        "\n",
        "    # Set components back to training mode\n",
        "    # Although components of `self` that are relevant for IPS have been set to eval mode,\n",
        "    # self is still in training mode at training time, i.e., we can use it here.\n",
        "    if self.training:\n",
        "        self.patch_encoder.train()\n",
        "        self.transf.train()\n",
        "\n",
        "    # Return selected patches\n",
        "    return mem_patch, mem_idx\n",
        "\n",
        "  def forward(self, mem_patch):\n",
        "    \"\"\"\n",
        "    After M patches have been selected during IPS, encode and aggregate them.\n",
        "    The aggregated embedding is input to a classification head.\n",
        "    \"\"\"\n",
        "    patch_shape = mem_patch.shape\n",
        "    B, M = patch_shape[:2]\n",
        "    \n",
        "    mem_emb = self.patch_encoder(mem_patch.reshape(-1, *patch_shape[2:]))\n",
        "    mem_emb = mem_emb.view(B, M, -1)        \n",
        "\n",
        "    image_emb = self.transf(mem_emb).squeeze(1)\n",
        "    pred = self.head(image_emb)\n",
        "    return pred\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSZdYJo7gVll"
      },
      "source": [
        "# Training loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yGp_Pvl8DoqA"
      },
      "outputs": [],
      "source": [
        "net = IPSNet(M, I, D).cuda()\n",
        "loss_fn = nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=0.1)\n",
        "\n",
        "for epoch in range(n_epoch):\n",
        "  net.train()\n",
        "\n",
        "  losses = []\n",
        "  correct, total = 0, 0\n",
        "  for data in train_loader:\n",
        "    image_patches, labels = data[0].cuda(), data[1].cuda()\n",
        "    mem_patch, mem_idx = net.ips(image_patches)\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "\n",
        "    preds = net(mem_patch)\n",
        "    loss = loss_fn(preds, labels)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    losses.append(loss.item())\n",
        "    \n",
        "    # Aggregate for accuracy metric\n",
        "    y_pred = torch.argmax(preds, dim=-1)\n",
        "    correct += (y_pred == labels).sum()\n",
        "    total += labels.shape[0]\n",
        "  \n",
        "  mean_loss = np.mean(losses)\n",
        "  accuracy = correct / total\n",
        "  print(f\"Epoch: {epoch}, train loss: {mean_loss}, accuracy: {accuracy}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qx17zLsE6o7S"
      },
      "source": [
        "# Test loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cMGQwFss6rxM"
      },
      "outputs": [],
      "source": [
        "#Evaluation on the full test set takes some time on Google Colab's standard GPU and is omitted by default.\n",
        "evaluate=False \n",
        "\n",
        "if evaluate:\n",
        "  net.eval()\n",
        "\n",
        "  losses = []\n",
        "  correct, total = 0, 0\n",
        "  with torch.no_grad():\n",
        "    for data in test_loader:\n",
        "      image_patches, labels = data[0].cuda(), data[1].cuda()\n",
        "      mem_patch, mem_idx = net.ips(image_patches)\n",
        "\n",
        "      preds = net(mem_patch)\n",
        "      loss = loss_fn(preds, labels)\n",
        "      losses.append(loss.item())\n",
        "      \n",
        "      y_pred = torch.argmax(preds, dim=-1)\n",
        "      correct += (y_pred == labels).sum()\n",
        "      total += labels.shape[0]\n",
        "\n",
        "  mean_loss = np.mean(losses)\n",
        "  accuracy = correct / total\n",
        "  print(f\"test loss: {mean_loss}, accuracy: {accuracy}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SPfpoGgdgjcD"
      },
      "source": [
        "# Visualization\n",
        "\n",
        "We pick some random images and calculate and visualize attention scores for them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qOBXWx2O80Di"
      },
      "outputs": [],
      "source": [
        "n_img_vis = 16\n",
        "rand_idx = torch.randperm(image_patches.shape[0])[:n_img_vis]\n",
        "\n",
        "# Select random patches and corresponding patch embeddings and indices\n",
        "mem_patch = mem_patch[rand_idx]\n",
        "mem_idx = mem_idx[rand_idx]\n",
        "patches = image_patches[rand_idx]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aXkoNl8wg2Hc"
      },
      "outputs": [],
      "source": [
        "# Obtain attention scores for selected patches\n",
        "net.eval()\n",
        "with torch.no_grad():\n",
        "  mem_emb = net.patch_encoder(mem_patch.reshape(-1, *mem_patch.shape[2:]))\n",
        "  mem_emb = mem_emb.view(n_img_vis, M, D)\n",
        "  attn_scores = net.transf.get_scores(mem_emb).cpu()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Se2fr7a_hCxC"
      },
      "outputs": [],
      "source": [
        "# Populate attention score map\n",
        "n_patches_per_side = int(img_size/patch_size)\n",
        "attn_map = torch.zeros((n_img_vis, n_chan, img_size, img_size))\n",
        "for i in range(n_img_vis):\n",
        "  for it, j in enumerate(mem_idx[i]):\n",
        "    row_id = (j // n_patches_per_side) * patch_stride\n",
        "    col_id = (j % n_patches_per_side) * patch_stride\n",
        "    attn_map[i, 0, row_id:row_id+patch_size, col_id:col_id+patch_size] += attn_scores[i, it]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Eh9twg7BhK5y"
      },
      "outputs": [],
      "source": [
        "# Fold patches back to image\n",
        "patches = patches.cpu().permute(0, 2, 1, 3, 4)\n",
        "patches = patches.view(n_img_vis, 1, n_patches_per_side**2, patch_size**2)\n",
        "patches = patches.permute(0, 1, 3, 2)\n",
        "patches = patches.reshape(n_img_vis, n_chan * patch_size**2, -1)\n",
        "images = F.fold(patches, output_size=(img_size, img_size), kernel_size=patch_size, stride=patch_stride)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VKcoF02GnTRE"
      },
      "outputs": [],
      "source": [
        "# Plot images and corresponding attention maps\n",
        "n_row = int(math.sqrt(n_img_vis))\n",
        "n_col = n_row\n",
        "\n",
        "fig = plt.figure(dpi=300, figsize=(19.20,9.83))\n",
        "for i, (img, attn) in enumerate(zip(images, attn_map)):\n",
        "  img = img.numpy()\n",
        "  img = (img - np.min(img))/np.ptp(img)\n",
        "  img = np.transpose(img, (1, 2, 0))\n",
        "\n",
        "  attn = attn.numpy()\n",
        "  attn = (attn - np.min(attn)) / np.ptp(attn)\n",
        "  attn = np.transpose(attn, (1, 2, 0))\n",
        "\n",
        "  ax = fig.add_subplot(n_row, n_col, i+1)      \n",
        "  ax.imshow(img.squeeze(), cmap='gray')\n",
        "  ax.imshow(attn.squeeze(), alpha=0.5)\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "pytorch_base38",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]"
    },
    "vscode": {
      "interpreter": {
        "hash": "cf80549a660a2b56efb0fe6bb9ed73f7efa081459207ca0e40a408447ec2df77"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
